package org.example.util;

import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.charset.StandardCharsets;
import java.util.Base64;
import java.util.Objects;
import java.util.Optional;

/**
 * Vector utilities: Base64 (de)serialization of float feature vectors and Euclidean distance.
 * Originally written as a helper while learning CNNs.
 *
 * <p>Vectors are serialized as the little-endian IEEE-754 bytes of each {@code float}, then
 * Base64-encoded with the basic (RFC 4648) alphabet. All methods are static and keep no state.
 *
 * @author KazuHo
 * @version 1.0
 * @since 2020-08-08
 */
public final class VectorUtil {

    /**
     * Encodes a float vector as a Base64 string.
     *
     * <p>Each element is written as 4 little-endian bytes before Base64 encoding, so the result
     * round-trips through {@link #decodeFeature(String)}.
     *
     * @param vector high-dimensional vector; must be non-null and non-empty
     * @return Base64 string of the vector's little-endian float bytes
     * @throws NullPointerException if {@code vector} is null or empty
     */
    public static String encodeFeature(float[] vector) {
        if (Objects.isNull(vector) || vector.length == 0) {
            // NPE (not IAE) is kept for compatibility with existing callers.
            throw new NullPointerException("invalid vector.");
        }

        ByteBuffer byteBuffer =
                ByteBuffer.allocate(vector.length * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
        // The float view shares the backing array and inherits the little-endian order.
        byteBuffer.asFloatBuffer().put(vector);

        return Base64.getEncoder().encodeToString(byteBuffer.array());
    }

    /**
     * Decodes a Base64 string (as produced by {@link #encodeFeature(float[])}) into a float vector.
     *
     * <p>Leading and trailing whitespace is trimmed before decoding. An empty string decodes to an
     * empty array.
     *
     * @param base64Vector Base64 string holding little-endian float bytes
     * @return the decoded float vector
     * @throws RuntimeException if {@code base64Vector} is null
     * @throws IllegalArgumentException if the input is not valid Base64, or the decoded byte count
     *     is not a multiple of {@link Float#BYTES}
     */
    public static float[] decodeFeature(String base64Vector) {
        byte[] bytes =
                Optional.ofNullable(base64Vector)
                        .map(String::trim)
                        .map(b -> b.getBytes(StandardCharsets.UTF_8))
                        .map(b -> Base64.getDecoder().decode(b))
                        .orElseThrow(
                                () ->
                                        new RuntimeException(
                                                String.format("invalid base64 vector [%s] given.", base64Vector)));

        // A trailing partial float would otherwise be read past the end of the array.
        if (bytes.length % Float.BYTES != 0) {
            throw new IllegalArgumentException(
                    String.format(
                            "invalid base64 vector [%s] given: decoded length %d is not a multiple of %d.",
                            base64Vector, bytes.length, Float.BYTES));
        }

        float[] floats = new float[bytes.length / Float.BYTES];
        // Mirror of encodeFeature: reinterpret the bytes as little-endian floats.
        ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer().get(floats);
        return floats;
    }

    /**
     * Computes the Euclidean distance between two vectors. It is used to judge image similarity:
     * the smaller the distance, the more similar the vectors.
     *
     * @param src   source vector
     * @param other vector to compare against
     * @return the Euclidean distance, or {@link Float#POSITIVE_INFINITY} if the lengths differ
     * @throws NullPointerException if either vector is null
     */
    public static float calcEuclideanDist(float[] src, float[] other) {
        if (Objects.isNull(src) || Objects.isNull(other)) {
            throw new NullPointerException("invalid src or other vector given.");
        }

        if (src.length != other.length) {
            // Vectors of different dimension are treated as infinitely far apart. This is the value
            // the previous (float) Double.MAX_VALUE cast produced, made explicit here.
            return Float.POSITIVE_INFINITY;
        }

        double sumOfSquares = 0;
        for (int i = 0; i < src.length; i++) {
            // Subtract in double to avoid float cancellation error.
            double diff = (double) src[i] - other[i];
            sumOfSquares += diff * diff;
        }

        return (float) Math.sqrt(sumOfSquares);
    }
}
